from tensorflow.python.framework import graph_util
from config import Config
from core.resnet import ResNet
from core.lstm import Lstm
from core.loss import Loss
from core.datalayer import DataLayers
from core.proposallayer import ProposalLayer
from core.textdetector import TextDetector
from core.utils import Utils
from sys import argv
import tensorflow as tf
import numpy as np
import cv2
import os

batch_size = 1

def show():
    """Print command-line usage for this script."""
    usage = 'python test.py [image] [dst_image]'
    print(usage)

class TextDetect:
    """Runs a frozen CTPN graph (.pb) to detect text boxes in images."""

    def __init__(self, pb_file_path):
        """Load the frozen graph at *pb_file_path* into a fresh session.

        pb_file_path: path to a serialized GraphDef (frozen .pb model).
        """
        self.sess = tf.Session()
        output_graph_def = tf.GraphDef()
        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")
        # Frozen graphs carry constants, so this init is effectively a
        # no-op, but it is kept to preserve the original behavior.
        init = tf.global_variables_initializer()
        self.sess.run(init)

    def detect(self, imgs, dst_path=None, score_thresh=0.9):
        """Detect text boxes in each image and draw the results.

        imgs: iterable of HxWx3 BGR images (uint8 numpy arrays assumed
            from the cv2.imread caller — TODO confirm dtype contract of
            ProposalLayer).
        dst_path: destination passed to Utils.draw_boxes. Defaults to
            argv[2] — the original code read the module-global `ls` set
            in __main__, which raised NameError when this module was
            imported instead of run as a script.
        score_thresh: minimum text score a proposal must reach to be
            kept (original hard-coded 0.9).
        """
        if dst_path is None:
            # Backward-compatible fallback matching the original script.
            dst_path = argv[2]

        # Tensor lookups are loop-invariant; resolve them once instead
        # of once per image.
        input_image = self.sess.graph.get_tensor_by_name("image:0")
        output_cls = self.sess.graph.get_tensor_by_name("fc_rpn_cls/output:0")
        output_pred = self.sess.graph.get_tensor_by_name("fc_rpn_pred/output:0")

        for img in imgs:
            height = img.shape[0]
            width = img.shape[1]
            # Feature map size = input size / total stride (16), rounded
            # up: ceil(x / 16) == (x + 15) // 16.
            image_info = {
                'image_data': img,
                'width': width,
                'height': height,
                'featuremap_h': (height + 15) // 16,
                'featuremap_w': (width + 15) // 16,
            }

            output_ls = self.sess.run(
                [output_cls, output_pred],
                feed_dict={input_image: img.reshape(-1, height, width, 3)})
            cls = np.reshape(output_ls[0], (-1, 2))
            box = np.reshape(output_ls[1], (-1, 4))

            proposals = ProposalLayer.c_generate_proposals(cls, box, 1, image_info, 16)
            # Keep only proposals whose score (column 1) clears the threshold.
            proposals = proposals[proposals[:, 1] >= score_thresh]
            boxes = TextDetector.c_detect(proposals[:, 2:], proposals[:, 1], [height, width])
            Utils.draw_boxes(img, dst_path, boxes)

if __name__ == "__main__":

    ls = argv
    if(len(ls)!=3):
        show()
        exit()

    img = cv2.imread(ls[1])
    im_scale = 1000 / img.shape[0]
    resizeh = 1000
    resizew = int(img.shape[1] * im_scale)
    img = cv2.resize(img, (resizew, resizeh))

    imgs = []
    imgs.append(img)

    td = TextDetect('model/ctpn.pb')
    td.detect(imgs)










